import os
os.environ["WANDB_API_KEY"] = "xxx"
os.environ["SWANLAB_API_KEY"] = "xxx"
import sys
sys.path.append('.')
import random
import numpy as np
import jax.numpy as jnp
from absl import app, flags
import datetime
import yaml
from ml_collections import config_flags, ConfigDict
import wandb
import swanlab
from tqdm.auto import trange  # noqa
import gymnasium as gym
from env.env_list import env_list
from env.point_robot import PointRobot
from jaxrl5.wrappers import wrap_gym
from jaxrl5.agents import FISOR
from jaxrl5.data.dsrl_datasets import DSRLDataset, env2cost_dict
from jaxrl5.evaluation import evaluate, evaluate_pr
import json
from osrl.algorithms import EnsembleDynamics, EnsembleDynamicsModel, EnsembleCostModel
from osrl.common import TransitionDataset
from osrl.common.exp_util import auto_name, seed_all
from osrl.common.model_logger import Logger
from osrl.common.net import StandardScaler, SimpleScaler, termination_fn_common
from osrl.common.buffer import ReplayBuffer
import torch
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
from collections import defaultdict
from copy import deepcopy as dco



# Command-line flags. main() copies most of these into FLAGS.config before
# training, so for those keys the flag value overrides the config file entry.
FLAGS = flags.FLAGS
flags.DEFINE_integer('env_id', 30, 'index into env_list selecting the environment')
flags.DEFINE_integer('conservative_cost_f', 0, 'whether to use conservative cost function')
flags.DEFINE_integer('wo_reflect', 0, 'whether to use conservative cost function without reflect')
flags.DEFINE_integer('wo_conserv', 0, 'whether to use conservative cost function without conservative cost')
flags.DEFINE_integer('is_deepseek', 0, 'whether to use conservative cost function generated by deepseek')
flags.DEFINE_integer('is_gemini', 0, 'whether to use conservative cost function generated by gemini')
flags.DEFINE_integer('rollout_std_decay', 0, 'whether to linearly decay rollout_std towards 0 over max_steps')
flags.DEFINE_integer('use_adamw', 0, 'whether to use AdamW optimizer')
flags.DEFINE_integer('use_unsafe_mask', 1,
                     'whether to keep only rollout transitions from trajectories that incurred a positive cost')
flags.DEFINE_integer('rollout_interval', 250000, 'number of training steps between model rollouts')
flags.DEFINE_float('ratio', 1.0, 'dataset ratio')
flags.DEFINE_float('penalty_coef', 0.5, 'uncertainty penalty coefficient of the dynamics model')
flags.DEFINE_float('mix_ratio', 0.95, 'dataset mixing ratio')
flags.DEFINE_integer('rollout_length', 1, 'maximum number of model steps per rollout')
flags.DEFINE_integer('rollout_epochs', 10, 'number of rollout batches generated at each rollout interval')
flags.DEFINE_string('project', '', 'swanlab project name (the config value is kept when empty)')
flags.DEFINE_string('device', 'cuda:0', 'PyTorch device for the dynamics model')
flags.DEFINE_string('experiment_name', '', 'experiment name for swanlab')
flags.DEFINE_integer('cost_limit', 10, 'cost limit')
flags.DEFINE_float('tau', 0.85, 'reverse expectile parameter of the critic')
flags.DEFINE_float('c_tau', 0.85, 'reverse expectile parameter of the cost critic')
flags.DEFINE_float('rollout_std', 0.1, 'std of the Gaussian noise added to policy actions during model rollouts')
config_flags.DEFINE_config_file(
    "config",
    None,
    "File path to the training hyperparameter configuration.",
    lock_config=False,
)

# env2dynamics = {
#     "OfflinePointButton1Gymnasium-v0": "../OSRL/logs_new/OfflinePointButton1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-364b/environment_model_safe_onlyTrue_simple_scalerTrue-364b/model",
#     "OfflinePointButton2Gymnasium-v0": "../OSRL/logs_new/OfflinePointButton2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-a4f2/environment_model_safe_onlyTrue_simple_scalerTrue-a4f2/model",
#     "OfflinePointCircle1Gymnasium-v0": "../OSRL/logs_new/OfflinePointCircle1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-c4c3/environment_model_safe_onlyTrue_simple_scalerTrue-c4c3/model",
#     "OfflinePointCircle2Gymnasium-v0": "../OSRL/logs_new/OfflinePointCircle2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-2427/environment_model_safe_onlyTrue_simple_scalerTrue-2427/model",
#     "OfflinePointGoal1Gymnasium-v0": "../OSRL/logs_new/OfflinePointGoal1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-412b/environment_model_safe_onlyTrue_simple_scalerTrue-412b/model",
#     "OfflinePointGoal2Gymnasium-v0": "../OSRL/logs_new/OfflinePointGoal2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-c822/environment_model_safe_onlyTrue_simple_scalerTrue-c822/model",
#     "OfflinePointPush1Gymnasium-v0": "../OSRL/logs_new/OfflinePointPush1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-9c60/environment_model_safe_onlyTrue_simple_scalerTrue-9c60/model",
#     "OfflinePointPush2Gymnasium-v0": "../OSRL/logs_new/OfflinePointPush2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-73f9/environment_model_safe_onlyTrue_simple_scalerTrue-73f9/model",
#     "OfflineCarButton1Gymnasium-v0": "../OSRL/logs_new/OfflineCarButton1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-06da/environment_model_safe_onlyTrue_simple_scalerTrue-06da/model",
#     "OfflineCarButton2Gymnasium-v0": "../OSRL/logs_new/OfflineCarButton2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-b6f2/environment_model_safe_onlyTrue_simple_scalerTrue-b6f2/model",
#     "OfflineCarCircle1Gymnasium-v0": "../OSRL/logs_new/OfflineCarCircle1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-c7c6/environment_model_safe_onlyTrue_simple_scalerTrue-c7c6/model",
#     "OfflineCarCircle2Gymnasium-v0": "../OSRL/logs_new/OfflineCarCircle2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-fe46/environment_model_safe_onlyTrue_simple_scalerTrue-fe46/model",
#     "OfflineCarGoal1Gymnasium-v0": "../OSRL/logs_new/OfflineCarGoal1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-a583/environment_model_safe_onlyTrue_simple_scalerTrue-a583/model",
#     "OfflineCarGoal2Gymnasium-v0": "../OSRL/logs_new/OfflineCarGoal2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-15f6/environment_model_safe_onlyTrue_simple_scalerTrue-15f6/model",
#     "OfflineCarPush1Gymnasium-v0": "../OSRL/logs_new/OfflineCarPush1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-78a7/environment_model_safe_onlyTrue_simple_scalerTrue-78a7/model",
#     "OfflineCarPush2Gymnasium-v0": "../OSRL/logs_new/OfflineCarPush2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-3c2a/environment_model_safe_onlyTrue_simple_scalerTrue-3c2a/model",
#     'OfflineAntVelocityGymnasium-v1': "../OSRL/logs_new/OfflineAntVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-1417/environment_model_safe_onlyTrue_simple_scalerTrue-1417/model",          # 16
#     'OfflineHalfCheetahVelocityGymnasium-v1': "../OSRL/logs_new/OfflineHalfCheetahVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-0a2e/environment_model_safe_onlyTrue_simple_scalerTrue-0a2e/model",  # 17
#     'OfflineHopperVelocityGymnasium-v1': "../OSRL/logs_new/OfflineHopperVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-c2d9/environment_model_safe_onlyTrue_simple_scalerTrue-c2d9/model",       # 18
#     'OfflineSwimmerVelocityGymnasium-v1': "../OSRL/logs_new/OfflineSwimmerVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-0692/environment_model_safe_onlyTrue_simple_scalerTrue-0692/model",      # 19
#     'OfflineWalker2dVelocityGymnasium-v1': "../OSRL/logs_new/OfflineWalker2dVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-512e/environment_model_safe_onlyTrue_simple_scalerTrue-512e/model",     # 20
#     "PointRobot": "../OSRL/logs_new/PointRobot/environment_model_safe_onlyTrue_simple_scalerTrue-e67a/environment_model_safe_onlyTrue_simple_scalerTrue-e67a/model",
#     "OfflineAntCircle-v0": "../OSRL/logs_new/OfflineAntCircle-v0/environment_model_safe_onlyTrue_simple_scalerTrue-8b6e/environment_model_safe_onlyTrue_simple_scalerTrue-8b6e/model"
# }
# Registry of pretrained dynamics-model checkpoints: environment name (the
# value of env_list[FLAGS.env_id]) -> checkpoint path. call_main looks up the
# current environment here and passes the path to EnsembleDynamics.load.
# Empty by default; the commented-out mapping above is a template of
# previously used checkpoint paths.
env2dynamics: Dict[str, str] = {}


def to_dict(config):
    """Recursively convert a ``ConfigDict`` into plain Python containers.

    ``ConfigDict`` instances and plain dicts become ``dict``s; lists and tuples
    are converted element-wise (lists stay lists, tuples become plain tuples);
    every other value is returned unchanged.  This lets the result be passed to
    ``json.dump`` even when ConfigDicts are nested inside sequences or dicts.
    """
    if isinstance(config, (ConfigDict, dict)):
        return {k: to_dict(v) for k, v in config.items()}
    if isinstance(config, list):
        return [to_dict(v) for v in config]
    if isinstance(config, tuple):
        return tuple(to_dict(v) for v in config)
    return config

def rollout(
    init_obss, rollout_length, agent, dynamics, cost_func, exp_sigma=0.1, use_unsafe_mask=True,
):
    """Roll the current policy through the learned dynamics model.

    Starting from ``init_obss``, the function repeats the following for up to
    ``rollout_length`` steps:

    - query ``agent.eval_actions``;
    - perturb the actions with Gaussian noise of std ``exp_sigma`` and clip
      them to [-1, 1];
    - call ``dynamics.safe_step``.

    Rows whose ``terminals`` flag is set are not stepped again, and the
    rollout stops early once every row has terminated.

    Args:
        init_obss: batch of start observations (first axis indexes the batch).
        rollout_length: maximum number of model steps; must be >= 1.
        agent: object whose ``eval_actions(obs)`` returns ``(actions, agent)``.
        dynamics: object whose ``safe_step(obs, actions, cost_func)`` returns
            ``(next_obs, rewards, terminals, info)`` with ``info["cost"]``.
        cost_func: forwarded unchanged to ``dynamics.safe_step``.
        exp_sigma: std of the exploration noise added to the actions.
        use_unsafe_mask: if true, keep only the transitions of trajectories
            that incurred a positive cost at some step of this rollout.

    Returns:
        ``(transitions, info)``.

        ``transitions`` maps ``observations``, ``next_observations``,
        ``actions``, ``dones``, ``rewards`` and ``costs`` to arrays
        concatenated over steps.

        ``info`` holds the number of kept transitions and the mean reward and
        cost over *all* generated transitions (before masking).

    Raises:
        ValueError: if ``rollout_length`` < 1.
    """
    if rollout_length < 1:
        raise ValueError(f"rollout_length must be >= 1, got {rollout_length}")

    rollout_transitions = defaultdict(list)
    reward_chunks, cost_chunks = [], []

    # Row j of init_obss is "trajectory j". `alive` holds the trajectory ids of
    # the rows currently being stepped (terminated rows are dropped), and
    # `step_rows[t]` records them for step t, so every stored transition can be
    # traced back to its trajectory when the unsafe mask is applied.
    num_trajs = len(init_obss)
    traj_unsafe = np.zeros(num_trajs, dtype=bool)
    alive = np.arange(num_trajs)
    step_rows = []

    observations = init_obss
    for _ in range(rollout_length):
        actions, agent = agent.eval_actions(observations)
        sigma = np.ones_like(actions) * exp_sigma
        actions = np.clip(np.random.normal(actions, sigma), -1.0, 1.0)
        next_observations, rewards, terminals, info = dynamics.safe_step(observations, actions, cost_func)
        costs = info["cost"]

        rollout_transitions["observations"].append(observations)
        rollout_transitions["next_observations"].append(next_observations)
        rollout_transitions["actions"].append(actions)
        rollout_transitions["dones"].append(terminals)
        rollout_transitions["rewards"].append(rewards)
        rollout_transitions["costs"].append(costs)
        reward_chunks.append(rewards.flatten())
        cost_chunks.append(costs.flatten())

        # Flag every still-alive trajectory that incurred a positive cost now.
        step_unsafe = (np.asarray(costs) > 0).reshape(len(alive), -1).any(axis=1)
        traj_unsafe[alive] |= step_unsafe
        step_rows.append(alive)

        nonterm_mask = (~terminals).flatten()
        if not nonterm_mask.any():
            break
        observations = next_observations[nonterm_mask]
        alive = alive[nonterm_mask]

    for key, chunks in rollout_transitions.items():
        rollout_transitions[key] = np.concatenate(chunks, axis=0)
    if use_unsafe_mask:
        # A transition is kept iff its trajectory was unsafe at ANY step.
        keep = np.concatenate([traj_unsafe[rows] for rows in step_rows])
        for key in rollout_transitions:
            rollout_transitions[key] = rollout_transitions[key][keep]

    # Statistics cover every generated transition (before masking), in float64.
    rewards_arr = np.concatenate(reward_chunks).astype(np.float64)
    costs_arr = np.concatenate(cost_chunks).astype(np.float64)
    return rollout_transitions, {
        "num_transitions": rollout_transitions["observations"].shape[0],
        "reward_mean": rewards_arr.mean(),
        "cost_mean": costs_arr.mean(),
    }


def _resolve_dynamics_path(details):
    """Return the pretrained dynamics checkpoint path for ``details['env_name']``.

    A non-empty ``details['dynamics_path']`` takes precedence over the
    module-level ``env2dynamics`` registry.

    Raises:
        KeyError: if neither source provides a path (same exception type as
            the plain ``env2dynamics[env_name]`` lookup it replaces).
    """
    env_name = details['env_name']
    path = details.get('dynamics_path') or env2dynamics.get(env_name)
    if not path:
        raise KeyError(
            f"No dynamics model checkpoint for {env_name!r}: add it to "
            "env2dynamics or set config.dynamics_path"
        )
    return path


def _make_env_and_datasets(details):
    """Build the environment and the two training datasets.

    Returns:
        ``(env, env_max_steps, ds, ds_real)``.

        - ``ds``: the dataset that model-rollout transitions are added to.
        - ``ds_real``: fed only with the offline data.
        - ``env``: non-PointRobot environments are wrapped with ``wrap_gym``.

    Raises:
        ValueError: for PointRobot when ``dataset_kwargs.pr_data`` is None.
    """
    dataset_kwargs = dict(
        critic_type=details['agent_kwargs']['critic_type'],
        cost_scale=details['dataset_kwargs']['cost_scale'],
        safe_only=True,
        env_name=details['env_name'],
        conservative_cost_f=details['conservative_cost_f'],
        wo_reflect=details['wo_reflect'],
        wo_conserv=details['wo_conserv'],
        is_deepseek=details['is_deepseek'],
        is_gemini=details['is_gemini'],
    )
    if details['env_name'] == 'PointRobot':
        pr_data = details['dataset_kwargs']['pr_data']
        if pr_data is None:
            raise ValueError("No data for Point Robot")
        env = PointRobot(id=0, seed=0)
        env_max_steps = env._max_episode_steps
        ds = DSRLDataset(env, data_location=pr_data, **dataset_kwargs)
        ds_real = DSRLDataset(env, data_location=pr_data, **dataset_kwargs)
    else:
        env = gym.make(details['env_name'])
        # Multi-step rollouts use DSRLDataset's `separate_buffer` mode; in that
        # mode the returns of `ds` are not normalised here.
        separate_buffer = details['rollout_length'] > 1
        ds = DSRLDataset(env, ratio=details['ratio'], separate_buffer=separate_buffer, **dataset_kwargs)
        ds_real = DSRLDataset(env, ratio=details['ratio'], **dataset_kwargs)
        env_max_steps = env._max_episode_steps
        env = wrap_gym(env, cost_limit=details['cost_limit'])
        if not separate_buffer:
            ds.normalize_returns(env.max_episode_reward, env.min_episode_reward, env_max_steps)
        ds_real.normalize_returns(env.max_episode_reward, env.min_episode_reward, env_max_steps)
    ds.seed(details["seed"])
    ds_real.seed(details["seed"])
    return env, env_max_steps, ds, ds_real


def _make_dynamics(details, env):
    """Construct the ``EnsembleDynamics`` wrapper (no separate cost model, no schedulers).

    Weights are not loaded here; the caller loads the checkpoint.
    """
    dynamics_model = EnsembleDynamicsModel(
        obs_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        hidden_dims=details['dynamics_hidden_dims'],
        num_ensemble=details['num_ensemble'],
        num_elites=details['num_elites'],
        weight_decays=details['dynamic_weight_decays'],
        with_cost=details['with_cost'],
        device=details['device'],
    )
    dynamics_optim = torch.optim.Adam(dynamics_model.parameters(), lr=details['dynamics_lr'])
    scaler = SimpleScaler() if details['simple_scaler'] else StandardScaler()
    return EnsembleDynamics(
        dynamics_model,
        None,  # cost_model
        dynamics_optim,
        None,  # cost_model_optim
        scaler,
        termination_fn_common,
        use_scheduler=details['use_scheduler'],
        dynamics_scheduler=None,
        cost_model_scheduler=None,
        penalty_coef=details['penalty_coef'],
        with_cost=details['with_cost'],
        use_delta_obs=details['use_delta_obs'],
        reward_scale=details['reward_scale'],
        cost_scale=details['cost_scale'],
        cost_coef=details['cost_coef'],
    )


def call_main(details):
    """Train the agent on offline data augmented with learned-model rollouts.

    Pipeline:

    1. Build the env and the datasets ``ds`` and ``ds_real``.
    2. Create the agent class named by ``agent_kwargs.model_cls``.
    3. Start a local swanlab run.
    4. Load the pretrained dynamics ensemble.
    5. For ``max_steps`` iterations:

       - every ``rollout_interval`` steps, generate ``rollout_epochs`` rollout
         batches from ``ds_real`` start states and add them to ``ds``;
       - update the agent with one batch from ``ds`` and one from ``ds_real``;
       - log training info every ``log_interval`` steps;
       - every ``eval_interval`` steps (including step 0), save a checkpoint
         and evaluate.

    Side effect: sets ``details['agent_kwargs']['cost_scale']``.
    Optional config keys:

    - ``dynamics_path``: overrides the ``env2dynamics`` registry.
    - ``fake_batch_size``: batch drawn from ``ds``; defaults to 256.
    """
    details['agent_kwargs']['cost_scale'] = details['dataset_kwargs']['cost_scale']
    # Resolve the checkpoint first so a missing entry fails before data loading.
    dynamics_path = _resolve_dynamics_path(details)
    is_point_robot = details['env_name'] == 'PointRobot'

    env, env_max_steps, ds, ds_real = _make_env_and_datasets(details)

    config_dict = dict(details['agent_kwargs'])
    config_dict['env_max_steps'] = env_max_steps
    model_cls = config_dict.pop("model_cls")
    config_dict.pop("cost_scale")
    agent = globals()[model_cls].create(
        details['seed'], env.observation_space, env.action_space, **config_dict
    )

    swanlab.init(project=details['project'], experiment_name=details['experiment_name'], config=config_dict, mode="local")

    dynamics = _make_dynamics(details, env)
    dynamics.load(dynamics_path)

    cost_func = env2cost_dict[details['env_name']]
    fake_batch_size = details.get('fake_batch_size', 256)
    save_dir = f"./results/{details['group']}/{details['experiment_name']}"

    # NOTE: details['mix_ratio'] is not used here; model data (ds) and real
    # data (ds_real) are sampled separately and both passed to agent.update.
    save_time = 1
    for i in trange(details['max_steps'], smoothing=0.1, desc=details['experiment_name']):

        # Model rollout.
        if i % details['rollout_interval'] == 0:
            rollout_std = details['rollout_std']
            if details['rollout_std_decay']:
                # Linear decay from rollout_std (step 0) towards 0 (max_steps).
                rollout_std *= 1 - i / details['max_steps']
            for _ in range(details['rollout_epochs']):
                init_obss = ds_real.sample(details['rollout_batch_size'])['observations']
                rollout_transitions, rollout_info = rollout(
                    init_obss, details['rollout_length'], agent, dynamics, cost_func,
                    exp_sigma=rollout_std, use_unsafe_mask=details['use_unsafe_mask'],
                )
                swanlab.log({f"rollout/{k}": v for k, v in rollout_info.items()}, step=i)
                if is_point_robot:
                    ds.add(rollout_transitions)
                else:
                    ds.add_and_norm(rollout_transitions, env.max_episode_reward, env.min_episode_reward, env_max_steps)

        # Policy training.
        sample = ds.sample_jax(fake_batch_size)
        real_sample = ds_real.sample_jax(details['batch_size'])
        agent, info = agent.update(sample, real_sample)

        if i % details['log_interval'] == 0:
            swanlab.log({f"train/{k}": v for k, v in info.items()}, step=i)

        if i % details['eval_interval'] == 0:
            agent.save(save_dir, save_time)
            save_time += 1
            if is_point_robot:
                eval_info = evaluate_pr(agent, env, details['eval_episodes'])
            else:
                eval_info = evaluate(agent, env, details['eval_episodes'])
                eval_info["normalized_return"], eval_info["normalized_cost"] = env.get_normalized_score(eval_info["return"], eval_info["cost"])
            swanlab.log({f"eval/{k}": v for k, v in eval_info.items()}, step=i)


def main(_):
    """absl entry point: merge command-line flags into the config, save it, and train.

    Flag values are written into ``FLAGS.config``; ``--project`` only
    overrides the config value when non-empty.  The experiment name is taken
    from ``--experiment_name`` when given, otherwise it is generated from the
    env, date, seed, rollout settings and ``tau``.  The merged config is dumped
    to ``./results/<env_name>/<experiment_name>/config.json`` before
    ``call_main`` runs.
    """
    parameters = FLAGS.config
    if FLAGS.project:
        parameters['project'] = FLAGS.project
    print(parameters)
    parameters['env_name'] = env_list[FLAGS.env_id]
    # Flags copied verbatim under the same top-level key (order kept stable).
    for name in ('ratio', 'conservative_cost_f', 'wo_reflect', 'wo_conserv',
                 'is_deepseek', 'is_gemini', 'cost_limit'):
        parameters[name] = getattr(FLAGS, name)
    parameters['agent_kwargs']['cost_limit'] = FLAGS.cost_limit
    parameters['group'] = parameters['env_name']

    parameters['agent_kwargs']['critic_hyperparam'] = FLAGS.tau
    parameters['agent_kwargs']['cost_critic_hyperparam'] = FLAGS.c_tau
    for name in ('device', 'penalty_coef', 'mix_ratio', 'use_unsafe_mask', 'rollout_std',
                 'rollout_length', 'rollout_epochs', 'rollout_interval',
                 'rollout_std_decay', 'use_adamw'):
        parameters[name] = getattr(FLAGS, name)

    if FLAGS.experiment_name:
        parameters['experiment_name'] = FLAGS.experiment_name
    else:
        prefix = "safer_llm_" if FLAGS.conservative_cost_f else ""
        parameters['experiment_name'] = (
            f"{prefix}{parameters['env_name']}_{datetime.date.today()}"
            f"_s{parameters['seed']}_rl{parameters['rollout_length']}"
            f"_ri{parameters['rollout_interval']}_tau{FLAGS.tau}"
        )

    # Previously used PointRobot overrides, kept for reference:
    # max_steps=100001, batch_size=1024, eval_interval=25000,
    # agent_kwargs: cost_temperature=2, reward_temperature=5, cost_ub=150, N=8

    print(parameters)

    result_dir = f"./results/{parameters['group']}/{parameters['experiment_name']}"
    os.makedirs(result_dir, exist_ok=True)
    with open(f"{result_dir}/config.json", "w", encoding="utf-8") as f:
        json.dump(to_dict(parameters), f, indent=4)

    call_main(parameters)


if __name__ == "__main__":
    # Hand control to absl, which invokes main(argv) once flags are parsed.
    app.run(main)
